#!/usr/bin/env python
# -*- coding: utf-8 -*-

import rospy
import numpy as np
import random
from collections import defaultdict
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
from src.turtlebot3_dqn.environment3 import Env
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

# hyperparameters
BATCH_SIZE = 32
LR = 0.01                  # learning rate
EPSILON = 0.9              # probability of greedily picking the best action
GAMMA = 0.9                # reward discount factor
TARGET_REPLACE_ITER = 100  # sync the target network every 100 learn steps
MEMORY_CAPACITY = 2000     # replay memory size
N_ACTIONS = 8              # agent actions 0, 1, ..., 7
N_STATES = 1

def trans_torch(list1):
    # Encode a grid of cell values {1, 2, 3} as three binary channels,
    # one per cell type, so the map can be fed to the convolutional net.
    list1 = np.array(list1)
    l1 = np.where(list1 == 1, 1, 0)
    l2 = np.where(list1 == 2, 1, 0)
    l3 = np.where(list1 == 3, 1, 0)
    return np.array([l1, l2, l3])
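
# Example (hypothetical 2x2 grid, not an actual Env observation):
#   trans_torch([[1, 2], [3, 1]]) returns an array of shape (3, 2, 2):
#     channel 0 (cells == 1): [[1, 0], [0, 1]]
#     channel 1 (cells == 2): [[0, 1], [0, 0]]
#     channel 2 (cells == 3): [[0, 0], [1, 0]]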

# neural network
class DQNNet(nn.Module):
    def __init__(self):
        super(DQNNet, self).__init__()
        self.c1 = nn.Conv2d(3, 17*17, 5, 2, 0)
        self.c2 = nn.Conv2d(17*17, 9*9, 5, 1, 0)
        self.c3 = nn.Conv2d(9*9, 25, 5, 1, 0)
        self.f1 = nn.Linear(25, 64)
        self.f1.weight.data.normal_(0, 0.1)
        self.f2 = nn.Linear(64, 8)
        self.f2.weight.data.normal_(0, 0.1)

    def forward(self, x):
        x = F.relu(self.c1(x))
        x = F.relu(self.c2(x))
        x = F.relu(self.c3(x))
        # Flatten the conv output into one feature vector per sample
        # before the fully connected layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.f1(x))
        action = self.f2(x)
        return action
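
# Shape check (a sketch; the layer sizes imply a 21x21 input grid, which is an
# assumption about what Env/trans_torch emit, since the code never fixes it):
#   c1: (21 - 5) // 2 + 1 = 9  -> (289, 9, 9)
#   c2: (9 - 5) // 1 + 1 = 5   -> (81, 5, 5)
#   c3: (5 - 5) // 1 + 1 = 1   -> (25, 1, 1), flattened to 25 features for f1
# so DQNNet()(torch.zeros(1, 3, 21, 21)).shape == torch.Size([1, 8])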



class DQLearningAgent(object):
    def __init__(self):
        # eval_net estimates Q(s, a); target_net provides the periodically
        # frozen targets for the Bellman update.
        self.eval_net, self.target_net = DQNNet(), DQNNet()
        self.learn_step_counter = 0  # counts learn() calls; target net syncs every TARGET_REPLACE_ITER
        self.memory_counter = 0      # number of transitions stored so far
        self.memory = list(np.zeros((MEMORY_CAPACITY, 8)))  # replay memory; each slot holds [s, a, r, s_]
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()  # mean-squared TD error

    def choose_action(self, x):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)
        if np.random.uniform() < EPSILON:
            actions_value = self.eval_net.forward(x)               # Q-values from the eval net
            action = torch.max(actions_value, 1)[1].data.numpy()   # index of the highest Q-value
        else:
            action = np.array([np.random.randint(0, N_ACTIONS)])   # random exploratory action
        return action
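
    # Usage sketch (assumed shapes): given a trans_torch state of shape
    # (3, H, W), choose_action returns a length-1 numpy array such as
    # np.array([3]); env.step() and store_transition consume it as-is.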
    def store_transition(self, s, a, r, s_):
        # Once the memory is full, overwrite the oldest transitions.
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index] = [s, a, r, s_]
        self.memory_counter += 1

    def learn(self):
        # Sync target_net with eval_net every TARGET_REPLACE_ITER learn steps.
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # Sample a batch of BATCH_SIZE transitions from the replay memory
        # (learn() is only called once the memory is full, so every slot is valid).
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_s = []
        b_a = []
        b_r = []
        b_s_ = []
        for i in sample_index:
            b_s.append(self.memory[i][0])
            b_a.append(np.array(self.memory[i][1], dtype=np.int64))
            b_r.append(np.array([self.memory[i][2]], dtype=np.float32))  # rewards are floats, not int32
            b_s_.append(self.memory[i][3])
        b_s = torch.FloatTensor(b_s)    # states
        b_a = torch.LongTensor(b_a)     # actions, shape (batch, 1)
        b_r = torch.FloatTensor(b_r)    # rewards, shape (batch, 1)
        b_s_ = torch.FloatTensor(b_s_)  # next states

        # Pick the Q-value of the action actually taken in each sampled transition.
        q_eval = self.eval_net(b_s).gather(1, b_a)   # shape (batch, 1); see the gather() sketch below
        q_next = self.target_net(b_s_).detach()      # detach: no gradients flow into the target net
        # Bellman target: r + gamma * max_a' Q_target(s', a'); view keeps shape (batch, 1)
        # so the addition does not silently broadcast to (batch, batch).
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)
        loss = self.loss_func(q_eval, q_target)

        # Backpropagate the TD error and update the eval net.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
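
    # gather() sketch with illustrative values (not from the environment):
    #   q = torch.tensor([[1., 2.], [3., 4.]])
    #   a = torch.tensor([[1], [0]])
    #   q.gather(1, a)  ->  tensor([[2.], [3.]])
    # Row i keeps the Q-value of the action taken in transition i, which is
    # exactly what q_eval needs above.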

if __name__ == "__main__":
    rospy.init_node('turtlebot3_DQLearning_stage_2')
    env = Env()
    agent = DQLearningAgent()
    study = 1  # one-shot flag so the "memory full" message prints only once

    for episode in range(1000):
        env.reset()
        state = env.start_env()
        state = trans_torch(state)
        while True:
            env.render()

            # Take an action and advance the environment by one step.
            action = agent.choose_action(state)
            next_state, reward, done = env.step(action)
            next_state = trans_torch(next_state)

            agent.store_transition(state, action, reward, next_state)
            if agent.memory_counter > MEMORY_CAPACITY:
                if study == 1:
                    print("Replay memory (%d transitions) is full; starting to learn" % MEMORY_CAPACITY)
                    study = 0
                agent.learn()

            state = next_state

            # The episode is over, so break out of the step loop.
            if done:
                break
